{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    },
    "colab": {
      "name": "2-4-5_SSD_model_forward.ipynb",
      "provenance": [],
      "collapsed_sections": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F94O4yYxGw8M",
        "colab_type": "text"
      },
      "source": [
        "# 2.4 ネットワークモデルの実装、2.5 順伝搬関数の実装\n",
        "\n",
        "本ファイルでは、SSDのネットワークモデルと順伝搬forward関数を作成します。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JPc7kqyjGw8Q",
        "colab_type": "text"
      },
      "source": [
        "# 2.4 学習目標\n",
        "\n",
        "1.\tSSDのネットワークモデルを構築している4つのモジュールを把握する\n",
        "2.\tSSDのネットワークモデルを作成できるようになる\n",
        "3.\tSSDで使用する様々な大きさのデフォルトボックスの実装方法を理解する\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lcaeSHrlGw8S",
        "colab_type": "text"
      },
      "source": [
        "# 2.5 学習目標\n",
        "\n",
        "1.\tNon-Maximum Suppressionを理解する\n",
        "2.\tSSDの推論時に使用するDetectクラスの順伝搬を理解する\n",
        "3.\tSSDの順伝搬を実装できるようになる\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eP5qFbGsGw8T",
        "colab_type": "text"
      },
      "source": [
        "# 事前準備\n",
        "\n",
        "\n",
        "とくになし"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5SezJNFsGw8U",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# パッケージのimport\n",
        "from math import sqrt\n",
        "from itertools import product\n",
        "\n",
        "import pandas as pd\n",
        "import torch\n",
        "from torch.autograd import Function\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.nn.init as init"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_dzZn2Y3Gw8a",
        "colab_type": "text"
      },
      "source": [
        "# vggモジュールを実装"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "seziJ76uGw8b",
        "colab_type": "code",
        "colab": {},
        "outputId": "a23d228c-e2fb-469b-c3b5-88ea6e5fb5df"
      },
      "source": [
        "# 35層にわたる、vggモジュールを作成\n",
        "def make_vgg():\n",
        "    layers = []\n",
        "    in_channels = 3  # 色チャネル数\n",
        "\n",
        "    # vggモジュールで使用する畳み込み層やマックスプーリングのチャネル数\n",
        "    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256,\n",
        "           256, 'MC', 512, 512, 512, 'M', 512, 512, 512]\n",
        "\n",
        "    for v in cfg:\n",
        "        if v == 'M':\n",
        "            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n",
        "        elif v == 'MC':\n",
        "            # ceilは出力サイズを、計算結果（float）に対して、切り上げで整数にするモード\n",
        "            # デフォルトでは出力サイズを計算結果（float）に対して、切り下げで整数にするfloorモード\n",
        "            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]\n",
        "        else:\n",
        "            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)\n",
        "            layers += [conv2d, nn.ReLU(inplace=True)]\n",
        "            in_channels = v\n",
        "\n",
        "    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
        "    conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)\n",
        "    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)\n",
        "    layers += [pool5, conv6,\n",
        "               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]\n",
        "    return nn.ModuleList(layers)\n",
        "\n",
        "\n",
        "# 動作確認\n",
        "vgg_test = make_vgg()\n",
        "print(vgg_test)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "ModuleList(\n",
            "  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (1): ReLU(inplace)\n",
            "  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (3): ReLU(inplace)\n",
            "  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
            "  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (6): ReLU(inplace)\n",
            "  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (8): ReLU(inplace)\n",
            "  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
            "  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (11): ReLU(inplace)\n",
            "  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (13): ReLU(inplace)\n",
            "  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (15): ReLU(inplace)\n",
            "  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)\n",
            "  (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (18): ReLU(inplace)\n",
            "  (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (20): ReLU(inplace)\n",
            "  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (22): ReLU(inplace)\n",
            "  (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
            "  (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (25): ReLU(inplace)\n",
            "  (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (27): ReLU(inplace)\n",
            "  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (29): ReLU(inplace)\n",
            "  (30): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
            "  (31): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))\n",
            "  (32): ReLU(inplace)\n",
            "  (33): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
            "  (34): ReLU(inplace)\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9T7XgWrKGw8g",
        "colab_type": "text"
      },
      "source": [
        "# extrasモジュールを実装"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2i04NYMoGw8h",
        "colab_type": "code",
        "colab": {},
        "outputId": "af5cf9cd-4011-4d63-fab5-7ba4eed77075"
      },
      "source": [
        "# 8層にわたる、extrasモジュールを作成\n",
        "def make_extras():\n",
        "    layers = []\n",
        "    in_channels = 1024  # vggモジュールから出力された、extraに入力される画像チャネル数\n",
        "\n",
        "    # extraモジュールの畳み込み層のチャネル数を設定するコンフィギュレーション\n",
        "    cfg = [256, 512, 128, 256, 128, 256, 128, 256]\n",
        "\n",
        "    layers += [nn.Conv2d(in_channels, cfg[0], kernel_size=(1))]\n",
        "    layers += [nn.Conv2d(cfg[0], cfg[1], kernel_size=(3), stride=2, padding=1)]\n",
        "    layers += [nn.Conv2d(cfg[1], cfg[2], kernel_size=(1))]\n",
        "    layers += [nn.Conv2d(cfg[2], cfg[3], kernel_size=(3), stride=2, padding=1)]\n",
        "    layers += [nn.Conv2d(cfg[3], cfg[4], kernel_size=(1))]\n",
        "    layers += [nn.Conv2d(cfg[4], cfg[5], kernel_size=(3))]\n",
        "    layers += [nn.Conv2d(cfg[5], cfg[6], kernel_size=(1))]\n",
        "    layers += [nn.Conv2d(cfg[6], cfg[7], kernel_size=(3))]\n",
        "    \n",
        "    # 活性化関数のReLUは今回はSSDモデルの順伝搬のなかで用意することにし、\n",
        "    # extraモジュールでは用意していません\n",
        "\n",
        "    return nn.ModuleList(layers)\n",
        "\n",
        "\n",
        "# 動作確認\n",
        "extras_test = make_extras()\n",
        "print(extras_test)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "ModuleList(\n",
            "  (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))\n",
            "  (1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
            "  (2): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "  (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
            "  (4): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "  (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))\n",
            "  (6): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "  (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lpMwWQfZGw8m",
        "colab_type": "text"
      },
      "source": [
        "# locモジュールとconfモジュールを実装"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-vXKr2WnGw8o",
        "colab_type": "code",
        "colab": {},
        "outputId": "f085d3ef-d4ea-48ff-cca0-c67aeafa8b63"
      },
      "source": [
        "# デフォルトボックスのオフセットを出力するloc_layers、\n",
        "# デフォルトボックスに対する各クラスの信頼度confidenceを出力するconf_layersを作成\n",
        "\n",
        "\n",
        "def make_loc_conf(num_classes=21, bbox_aspect_num=[4, 6, 6, 6, 4, 4]):\n",
        "\n",
        "    loc_layers = []\n",
        "    conf_layers = []\n",
        "\n",
        "    # VGGの22層目、conv4_3（source1）に対する畳み込み層\n",
        "    loc_layers += [nn.Conv2d(512, bbox_aspect_num[0]\n",
        "                             * 4, kernel_size=3, padding=1)]\n",
        "    conf_layers += [nn.Conv2d(512, bbox_aspect_num[0]\n",
        "                              * num_classes, kernel_size=3, padding=1)]\n",
        "\n",
        "    # VGGの最終層（source2）に対する畳み込み層\n",
        "    loc_layers += [nn.Conv2d(1024, bbox_aspect_num[1]\n",
        "                             * 4, kernel_size=3, padding=1)]\n",
        "    conf_layers += [nn.Conv2d(1024, bbox_aspect_num[1]\n",
        "                              * num_classes, kernel_size=3, padding=1)]\n",
        "\n",
        "    # extraの（source3）に対する畳み込み層\n",
        "    loc_layers += [nn.Conv2d(512, bbox_aspect_num[2]\n",
        "                             * 4, kernel_size=3, padding=1)]\n",
        "    conf_layers += [nn.Conv2d(512, bbox_aspect_num[2]\n",
        "                              * num_classes, kernel_size=3, padding=1)]\n",
        "\n",
        "    # extraの（source4）に対する畳み込み層\n",
        "    loc_layers += [nn.Conv2d(256, bbox_aspect_num[3]\n",
        "                             * 4, kernel_size=3, padding=1)]\n",
        "    conf_layers += [nn.Conv2d(256, bbox_aspect_num[3]\n",
        "                              * num_classes, kernel_size=3, padding=1)]\n",
        "\n",
        "    # extraの（source5）に対する畳み込み層\n",
        "    loc_layers += [nn.Conv2d(256, bbox_aspect_num[4]\n",
        "                             * 4, kernel_size=3, padding=1)]\n",
        "    conf_layers += [nn.Conv2d(256, bbox_aspect_num[4]\n",
        "                              * num_classes, kernel_size=3, padding=1)]\n",
        "\n",
        "    # extraの（source6）に対する畳み込み層\n",
        "    loc_layers += [nn.Conv2d(256, bbox_aspect_num[5]\n",
        "                             * 4, kernel_size=3, padding=1)]\n",
        "    conf_layers += [nn.Conv2d(256, bbox_aspect_num[5]\n",
        "                              * num_classes, kernel_size=3, padding=1)]\n",
        "\n",
        "    return nn.ModuleList(loc_layers), nn.ModuleList(conf_layers)\n",
        "\n",
        "\n",
        "# 動作確認\n",
        "loc_test, conf_test = make_loc_conf()\n",
        "print(loc_test)\n",
        "print(conf_test)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "ModuleList(\n",
            "  (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (4): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (5): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            ")\n",
            "ModuleList(\n",
            "  (0): Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (1): Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (2): Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (3): Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (4): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  (5): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tOgBNY7FGw8s",
        "colab_type": "text"
      },
      "source": [
        "# L2Norm層を実装"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8EwzMY9fGw8u",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# convC4_3からの出力をscale=20のL2Normで正規化する層\n",
        "class L2Norm(nn.Module):\n",
        "    def __init__(self, input_channels=512, scale=20):\n",
        "        super(L2Norm, self).__init__()  # 親クラスのコンストラクタ実行\n",
        "        self.weight = nn.Parameter(torch.Tensor(input_channels))\n",
        "        self.scale = scale  # 係数weightの初期値として設定する値\n",
        "        self.reset_parameters()  # パラメータの初期化\n",
        "        self.eps = 1e-10\n",
        "\n",
        "    def reset_parameters(self):\n",
        "        '''結合パラメータを大きさscaleの値にする初期化を実行'''\n",
        "        init.constant_(self.weight, self.scale)  # weightの値がすべてscale（=20）になる\n",
        "\n",
        "    def forward(self, x):\n",
        "        '''38×38の特徴量に対して、512チャネルにわたって2乗和のルートを求めた\n",
        "        38×38個の値を使用し、各特徴量を正規化してから係数をかけ算する層'''\n",
        "\n",
        "        # 各チャネルにおける38×38個の特徴量のチャネル方向の2乗和を計算し、\n",
        "        # さらにルートを求め、割り算して正規化する\n",
        "        # normのテンソルサイズはtorch.Size([batch_num, 1, 38, 38])になります\n",
        "        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps\n",
        "        x = torch.div(x, norm)\n",
        "\n",
        "        # 係数をかける。係数はチャネルごとに1つで、512個の係数を持つ\n",
        "        # self.weightのテンソルサイズはtorch.Size([512])なので\n",
        "        # torch.Size([batch_num, 512, 38, 38])まで変形します\n",
        "        weights = self.weight.unsqueeze(\n",
        "            0).unsqueeze(2).unsqueeze(3).expand_as(x)\n",
        "        out = weights * x\n",
        "\n",
        "        return out\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0zN8JIPvGw8x",
        "colab_type": "text"
      },
      "source": [
        "# デフォルトボックスを実装"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q6K7TmuLGw8z",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# デフォルトボックスを出力するクラス\n",
        "class DBox(object):\n",
        "    def __init__(self, cfg):\n",
        "        super(DBox, self).__init__()\n",
        "\n",
        "        # 初期設定\n",
        "        self.image_size = cfg['input_size']  # 画像サイズの300\n",
        "        # [38, 19, …] 各sourceの特徴量マップのサイズ\n",
        "        self.feature_maps = cfg['feature_maps']\n",
        "        self.num_priors = len(cfg[\"feature_maps\"])  # sourceの個数=6\n",
        "        self.steps = cfg['steps']  # [8, 16, …] DBoxのピクセルサイズ\n",
        "        \n",
        "        self.min_sizes = cfg['min_sizes']\n",
        "        # [30, 60, …] 小さい正方形のDBoxのピクセルサイズ（正確には面積）\n",
        "        \n",
        "        self.max_sizes = cfg['max_sizes']\n",
        "        # [60, 111, …] 大きい正方形のDBoxのピクセルサイズ（正確には面積）\n",
        "        \n",
        "        self.aspect_ratios = cfg['aspect_ratios']  # 長方形のDBoxのアスペクト比\n",
        "\n",
        "    def make_dbox_list(self):\n",
        "        '''DBoxを作成する'''\n",
        "        mean = []\n",
        "        # 'feature_maps': [38, 19, 10, 5, 3, 1]\n",
        "        for k, f in enumerate(self.feature_maps):\n",
        "            for i, j in product(range(f), repeat=2):  # fまでの数で2ペアの組み合わせを作る　f_P_2 個\n",
        "                # 特徴量の画像サイズ\n",
        "                # 300 / 'steps': [8, 16, 32, 64, 100, 300],\n",
        "                f_k = self.image_size / self.steps[k]\n",
        "\n",
        "                # DBoxの中心座標 x,y　ただし、0～1で規格化している\n",
        "                cx = (j + 0.5) / f_k\n",
        "                cy = (i + 0.5) / f_k\n",
        "\n",
        "                # アスペクト比1の小さいDBox [cx,cy, width, height]\n",
        "                # 'min_sizes': [30, 60, 111, 162, 213, 264]\n",
        "                s_k = self.min_sizes[k]/self.image_size\n",
        "                mean += [cx, cy, s_k, s_k]\n",
        "\n",
        "                # アスペクト比1の大きいDBox [cx,cy, width, height]\n",
        "                # 'max_sizes': [60, 111, 162, 213, 264, 315],\n",
        "                s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))\n",
        "                mean += [cx, cy, s_k_prime, s_k_prime]\n",
        "\n",
        "                # その他のアスペクト比のdefBox [cx,cy, width, height]\n",
        "                for ar in self.aspect_ratios[k]:\n",
        "                    mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)]\n",
        "                    mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)]\n",
        "\n",
        "        # DBoxをテンソルに変換 torch.Size([8732, 4])\n",
        "        output = torch.Tensor(mean).view(-1, 4)\n",
        "\n",
        "        # DBoxが画像の外にはみ出るのを防ぐため、大きさを最小0、最大1にする\n",
        "        output.clamp_(max=1, min=0)\n",
        "\n",
        "        return output\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rMQnyO-VGw83",
        "colab_type": "code",
        "colab": {},
        "outputId": "48ff0468-4eac-4d33-ae90-92ffb86fa8a4"
      },
      "source": [
        "# 動作の確認\n",
        "\n",
        "# SSD300の設定\n",
        "ssd_cfg = {\n",
        "    'num_classes': 21,  # 背景クラスを含めた合計クラス数\n",
        "    'input_size': 300,  # 画像の入力サイズ\n",
        "    'bbox_aspect_num': [4, 6, 6, 6, 4, 4],  # 出力するDBoxのアスペクト比の種類\n",
        "    'feature_maps': [38, 19, 10, 5, 3, 1],  # 各sourceの画像サイズ\n",
        "    'steps': [8, 16, 32, 64, 100, 300],  # DBOXの大きさを決める\n",
        "    'min_sizes': [30, 60, 111, 162, 213, 264],  # DBOXの大きさを決める\n",
        "    'max_sizes': [60, 111, 162, 213, 264, 315],  # DBOXの大きさを決める\n",
        "    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],\n",
        "}\n",
        "\n",
        "# DBox作成\n",
        "dbox = DBox(ssd_cfg)\n",
        "dbox_list = dbox.make_dbox_list()\n",
        "\n",
        "# DBoxの出力を確認する\n",
        "pd.DataFrame(dbox_list.numpy())\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>0</th>\n",
              "      <th>1</th>\n",
              "      <th>2</th>\n",
              "      <th>3</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.070711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.070711</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>0.040000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>0.040000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>0.040000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.070711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>0.040000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.070711</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>0.066667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>0.066667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>10</th>\n",
              "      <td>0.066667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.070711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>11</th>\n",
              "      <td>0.066667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.070711</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>12</th>\n",
              "      <td>0.093333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>13</th>\n",
              "      <td>0.093333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>14</th>\n",
              "      <td>0.093333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.070711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>15</th>\n",
              "      <td>0.093333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.070711</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>16</th>\n",
              "      <td>0.120000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>17</th>\n",
              "      <td>0.120000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>18</th>\n",
              "      <td>0.120000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.070711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>19</th>\n",
              "      <td>0.120000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.070711</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>20</th>\n",
              "      <td>0.146667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>21</th>\n",
              "      <td>0.146667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>22</th>\n",
              "      <td>0.146667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.070711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>23</th>\n",
              "      <td>0.146667</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.070711</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>24</th>\n",
              "      <td>0.173333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>25</th>\n",
              "      <td>0.173333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>26</th>\n",
              "      <td>0.173333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.070711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>27</th>\n",
              "      <td>0.173333</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.070711</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>28</th>\n",
              "      <td>0.200000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.100000</td>\n",
              "      <td>0.100000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>29</th>\n",
              "      <td>0.200000</td>\n",
              "      <td>0.013333</td>\n",
              "      <td>0.141421</td>\n",
              "      <td>0.141421</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8702</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.166667</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.502046</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8703</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.502046</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8704</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.710000</td>\n",
              "      <td>0.710000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8705</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.790443</td>\n",
              "      <td>0.790443</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8706</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.502046</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8707</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.502046</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8708</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.710000</td>\n",
              "      <td>0.710000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8709</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.790443</td>\n",
              "      <td>0.790443</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8710</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.502046</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8711</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.502046</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8712</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.710000</td>\n",
              "      <td>0.710000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8713</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.790443</td>\n",
              "      <td>0.790443</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8714</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.502046</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8715</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.502046</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8716</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.710000</td>\n",
              "      <td>0.710000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8717</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.790443</td>\n",
              "      <td>0.790443</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8718</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.502046</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8719</th>\n",
              "      <td>0.166667</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.502046</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8720</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.710000</td>\n",
              "      <td>0.710000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8721</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.790443</td>\n",
              "      <td>0.790443</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8722</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.502046</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8723</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.502046</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8724</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.710000</td>\n",
              "      <td>0.710000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8725</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.790443</td>\n",
              "      <td>0.790443</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8726</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.502046</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8727</th>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.833333</td>\n",
              "      <td>0.502046</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8728</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.880000</td>\n",
              "      <td>0.880000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8729</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.961249</td>\n",
              "      <td>0.961249</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8730</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.622254</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8731</th>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.622254</td>\n",
              "      <td>1.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>8732 rows × 4 columns</p>\n",
              "</div>"
            ],
            "text/plain": [
              "             0         1         2         3\n",
              "0     0.013333  0.013333  0.100000  0.100000\n",
              "1     0.013333  0.013333  0.141421  0.141421\n",
              "2     0.013333  0.013333  0.141421  0.070711\n",
              "3     0.013333  0.013333  0.070711  0.141421\n",
              "4     0.040000  0.013333  0.100000  0.100000\n",
              "5     0.040000  0.013333  0.141421  0.141421\n",
              "6     0.040000  0.013333  0.141421  0.070711\n",
              "7     0.040000  0.013333  0.070711  0.141421\n",
              "8     0.066667  0.013333  0.100000  0.100000\n",
              "9     0.066667  0.013333  0.141421  0.141421\n",
              "10    0.066667  0.013333  0.141421  0.070711\n",
              "11    0.066667  0.013333  0.070711  0.141421\n",
              "12    0.093333  0.013333  0.100000  0.100000\n",
              "13    0.093333  0.013333  0.141421  0.141421\n",
              "14    0.093333  0.013333  0.141421  0.070711\n",
              "15    0.093333  0.013333  0.070711  0.141421\n",
              "16    0.120000  0.013333  0.100000  0.100000\n",
              "17    0.120000  0.013333  0.141421  0.141421\n",
              "18    0.120000  0.013333  0.141421  0.070711\n",
              "19    0.120000  0.013333  0.070711  0.141421\n",
              "20    0.146667  0.013333  0.100000  0.100000\n",
              "21    0.146667  0.013333  0.141421  0.141421\n",
              "22    0.146667  0.013333  0.141421  0.070711\n",
              "23    0.146667  0.013333  0.070711  0.141421\n",
              "24    0.173333  0.013333  0.100000  0.100000\n",
              "25    0.173333  0.013333  0.141421  0.141421\n",
              "26    0.173333  0.013333  0.141421  0.070711\n",
              "27    0.173333  0.013333  0.070711  0.141421\n",
              "28    0.200000  0.013333  0.100000  0.100000\n",
              "29    0.200000  0.013333  0.141421  0.141421\n",
              "...        ...       ...       ...       ...\n",
              "8702  0.833333  0.166667  1.000000  0.502046\n",
              "8703  0.833333  0.166667  0.502046  1.000000\n",
              "8704  0.166667  0.500000  0.710000  0.710000\n",
              "8705  0.166667  0.500000  0.790443  0.790443\n",
              "8706  0.166667  0.500000  1.000000  0.502046\n",
              "8707  0.166667  0.500000  0.502046  1.000000\n",
              "8708  0.500000  0.500000  0.710000  0.710000\n",
              "8709  0.500000  0.500000  0.790443  0.790443\n",
              "8710  0.500000  0.500000  1.000000  0.502046\n",
              "8711  0.500000  0.500000  0.502046  1.000000\n",
              "8712  0.833333  0.500000  0.710000  0.710000\n",
              "8713  0.833333  0.500000  0.790443  0.790443\n",
              "8714  0.833333  0.500000  1.000000  0.502046\n",
              "8715  0.833333  0.500000  0.502046  1.000000\n",
              "8716  0.166667  0.833333  0.710000  0.710000\n",
              "8717  0.166667  0.833333  0.790443  0.790443\n",
              "8718  0.166667  0.833333  1.000000  0.502046\n",
              "8719  0.166667  0.833333  0.502046  1.000000\n",
              "8720  0.500000  0.833333  0.710000  0.710000\n",
              "8721  0.500000  0.833333  0.790443  0.790443\n",
              "8722  0.500000  0.833333  1.000000  0.502046\n",
              "8723  0.500000  0.833333  0.502046  1.000000\n",
              "8724  0.833333  0.833333  0.710000  0.710000\n",
              "8725  0.833333  0.833333  0.790443  0.790443\n",
              "8726  0.833333  0.833333  1.000000  0.502046\n",
              "8727  0.833333  0.833333  0.502046  1.000000\n",
              "8728  0.500000  0.500000  0.880000  0.880000\n",
              "8729  0.500000  0.500000  0.961249  0.961249\n",
              "8730  0.500000  0.500000  1.000000  0.622254\n",
              "8731  0.500000  0.500000  0.622254  1.000000\n",
              "\n",
              "[8732 rows x 4 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WNgaY_bUGw87",
        "colab_type": "text"
      },
      "source": [
        "# SSDクラスを実装"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "soAjguIYGw89",
        "colab_type": "code",
        "colab": {},
        "outputId": "10c682f5-84d2-4646-eb90-88fc739a6495"
      },
      "source": [
        "# SSDクラスを作成する\n",
        "class SSD(nn.Module):\n",
        "\n",
        "    def __init__(self, phase, cfg):\n",
        "        super(SSD, self).__init__()\n",
        "\n",
        "        self.phase = phase  # train or inferenceを指定\n",
        "        self.num_classes = cfg[\"num_classes\"]  # クラス数=21\n",
        "\n",
        "        # SSDのネットワークを作る\n",
        "        self.vgg = make_vgg()\n",
        "        self.extras = make_extras()\n",
        "        self.L2Norm = L2Norm()\n",
        "        self.loc, self.conf = make_loc_conf(\n",
        "            cfg[\"num_classes\"], cfg[\"bbox_aspect_num\"])\n",
        "\n",
        "        # DBox作成\n",
        "        dbox = DBox(cfg)\n",
        "        self.dbox_list = dbox.make_dbox_list()\n",
        "\n",
        "        # 推論時はクラス「Detect」を用意します\n",
        "        if phase == 'inference':\n",
        "            self.detect = Detect()\n",
        "\n",
        "\n",
        "# 動作確認\n",
        "ssd_test = SSD(phase=\"train\", cfg=ssd_cfg)\n",
        "print(ssd_test)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "SSD(\n",
            "  (vgg): ModuleList(\n",
            "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (1): ReLU(inplace)\n",
            "    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (3): ReLU(inplace)\n",
            "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
            "    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (6): ReLU(inplace)\n",
            "    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (8): ReLU(inplace)\n",
            "    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
            "    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (11): ReLU(inplace)\n",
            "    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (13): ReLU(inplace)\n",
            "    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (15): ReLU(inplace)\n",
            "    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)\n",
            "    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (18): ReLU(inplace)\n",
            "    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (20): ReLU(inplace)\n",
            "    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (22): ReLU(inplace)\n",
            "    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
            "    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (25): ReLU(inplace)\n",
            "    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (27): ReLU(inplace)\n",
            "    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (29): ReLU(inplace)\n",
            "    (30): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
            "    (31): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))\n",
            "    (32): ReLU(inplace)\n",
            "    (33): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
            "    (34): ReLU(inplace)\n",
            "  )\n",
            "  (extras): ModuleList(\n",
            "    (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))\n",
            "    (1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
            "    (2): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
            "    (4): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "    (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))\n",
            "    (6): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))\n",
            "  )\n",
            "  (L2Norm): L2Norm()\n",
            "  (loc): ModuleList(\n",
            "    (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (4): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (5): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  )\n",
            "  (conf): ModuleList(\n",
            "    (0): Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (1): Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (2): Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (3): Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (4): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (5): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "  )\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e2kmA27KGw9C",
        "colab_type": "text"
      },
      "source": [
        "# ここから2.5節 順伝搬の実装です"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4PiaKi9QGw9E",
        "colab_type": "text"
      },
      "source": [
        "# 関数decodeを実装する"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0hL7JLRyGw9F",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# オフセット情報を使い、DBoxをBBoxに変換する関数\n",
        "\n",
        "\n",
        "def decode(loc, dbox_list):\n",
        "    \"\"\"\n",
        "    オフセット情報を使い、DBoxをBBoxに変換する。\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    loc:  [8732,4]\n",
        "        SSDモデルで推論するオフセット情報。\n",
        "    dbox_list: [8732,4]\n",
        "        DBoxの情報\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    boxes : [xmin, ymin, xmax, ymax]\n",
        "        BBoxの情報\n",
        "    \"\"\"\n",
        "\n",
        "    # DBoxは[cx, cy, width, height]で格納されている\n",
        "    # locも[Δcx, Δcy, Δwidth, Δheight]で格納されている\n",
        "\n",
        "    # オフセット情報からBBoxを求める\n",
        "    boxes = torch.cat((\n",
        "        dbox_list[:, :2] + loc[:, :2] * 0.1 * dbox_list[:, 2:],\n",
        "        dbox_list[:, 2:] * torch.exp(loc[:, 2:] * 0.2)), dim=1)\n",
        "    # boxesのサイズはtorch.Size([8732, 4])となります\n",
        "\n",
        "    # BBoxの座標情報を[cx, cy, width, height]から[xmin, ymin, xmax, ymax] に\n",
        "    boxes[:, :2] -= boxes[:, 2:] / 2  # 座標(xmin,ymin)へ変換\n",
        "    boxes[:, 2:] += boxes[:, :2]  # 座標(xmax,ymax)へ変換\n",
        "\n",
        "    return boxes\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GRWdh3ZlGw9K",
        "colab_type": "text"
      },
      "source": [
        "# Non-Maximum Suppressionを行う関数を実装する"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wtndAFgxGw9M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Non-Maximum Suppressionを行う関数\n",
        "\n",
        "\n",
        "def nm_suppression(boxes, scores, overlap=0.45, top_k=200):\n",
        "    \"\"\"\n",
        "    Non-Maximum Suppressionを行う関数。\n",
        "    boxesのうち被り過ぎ（overlap以上）のBBoxを削除する。\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    boxes : [確信度閾値（0.01）を超えたBBox数,4]\n",
        "        BBox情報。\n",
        "    scores :[確信度閾値（0.01）を超えたBBox数]\n",
        "        confの情報\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    keep : リスト\n",
        "        confの降順にnmsを通過したindexが格納\n",
        "    count：int\n",
        "        nmsを通過したBBoxの数\n",
        "    \"\"\"\n",
        "\n",
        "    # returnのひな形を作成\n",
        "    count = 0\n",
        "    keep = scores.new(scores.size(0)).zero_().long()\n",
        "    # keep：torch.Size([確信度閾値を超えたBBox数])、要素は全部0\n",
        "\n",
        "    # 各BBoxの面積areaを計算\n",
        "    x1 = boxes[:, 0]\n",
        "    y1 = boxes[:, 1]\n",
        "    x2 = boxes[:, 2]\n",
        "    y2 = boxes[:, 3]\n",
        "    area = torch.mul(x2 - x1, y2 - y1)\n",
        "\n",
        "    # boxesをコピーする。後で、BBoxの被り度合いIOUの計算に使用する際のひな形として用意\n",
        "    tmp_x1 = boxes.new()\n",
        "    tmp_y1 = boxes.new()\n",
        "    tmp_x2 = boxes.new()\n",
        "    tmp_y2 = boxes.new()\n",
        "    tmp_w = boxes.new()\n",
        "    tmp_h = boxes.new()\n",
        "\n",
        "    # socreを昇順に並び変える\n",
        "    v, idx = scores.sort(0)\n",
        "\n",
        "    # 上位top_k個（200個）のBBoxのindexを取り出す（200個存在しない場合もある）\n",
        "    idx = idx[-top_k:]\n",
        "\n",
        "    # idxの要素数が0でない限りループする\n",
        "    while idx.numel() > 0:\n",
        "        i = idx[-1]  # 現在のconf最大のindexをiに\n",
        "\n",
        "        # keepの現在の最後にconf最大のindexを格納する\n",
        "        # このindexのBBoxと被りが大きいBBoxをこれから消去する\n",
        "        keep[count] = i\n",
        "        count += 1\n",
        "\n",
        "        # 最後のBBoxになった場合は、ループを抜ける\n",
        "        if idx.size(0) == 1:\n",
        "            break\n",
        "\n",
        "        # 現在のconf最大のindexをkeepに格納したので、idxをひとつ減らす\n",
        "        idx = idx[:-1]\n",
        "\n",
        "        # -------------------\n",
        "        # これからkeepに格納したBBoxと被りの大きいBBoxを抽出して除去する\n",
        "        # -------------------\n",
        "        # ひとつ減らしたidxまでのBBoxを、outに指定した変数として作成する\n",
        "        torch.index_select(x1, 0, idx, out=tmp_x1)\n",
        "        torch.index_select(y1, 0, idx, out=tmp_y1)\n",
        "        torch.index_select(x2, 0, idx, out=tmp_x2)\n",
        "        torch.index_select(y2, 0, idx, out=tmp_y2)\n",
        "\n",
        "        # すべてのBBoxに対して、現在のBBox=indexがiと被っている値までに設定(clamp)\n",
        "        tmp_x1 = torch.clamp(tmp_x1, min=x1[i])\n",
        "        tmp_y1 = torch.clamp(tmp_y1, min=y1[i])\n",
        "        tmp_x2 = torch.clamp(tmp_x2, max=x2[i])\n",
        "        tmp_y2 = torch.clamp(tmp_y2, max=y2[i])\n",
        "\n",
        "        # wとhのテンソルサイズをindexを1つ減らしたものにする\n",
        "        tmp_w.resize_as_(tmp_x2)\n",
        "        tmp_h.resize_as_(tmp_y2)\n",
        "\n",
        "        # clampした状態でのBBoxの幅と高さを求める\n",
        "        tmp_w = tmp_x2 - tmp_x1\n",
        "        tmp_h = tmp_y2 - tmp_y1\n",
        "\n",
        "        # 幅や高さが負になっているものは0にする\n",
        "        tmp_w = torch.clamp(tmp_w, min=0.0)\n",
        "        tmp_h = torch.clamp(tmp_h, min=0.0)\n",
        "\n",
        "        # clampされた状態での面積を求める\n",
        "        inter = tmp_w*tmp_h\n",
        "\n",
        "        # IoU = intersect部分 / (area(a) + area(b) - intersect部分)の計算\n",
        "        rem_areas = torch.index_select(area, 0, idx)  # 各BBoxの元の面積\n",
        "        union = (rem_areas - inter) + area[i]  # 2つのエリアの和（OR）の面積\n",
        "        IoU = inter/union\n",
        "\n",
        "        # IoUがoverlapより小さいidxのみを残す\n",
        "        idx = idx[IoU.le(overlap)]  # leはLess than or Equal toの処理をする演算です\n",
        "        # IoUがoverlapより大きいidxは、最初に選んでkeepに格納したidxと同じ物体に対してBBoxを囲んでいるため消去\n",
        "\n",
        "    # whileのループが抜けたら終了\n",
        "\n",
        "    return keep, count\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hf1a7f4nGw9S",
        "colab_type": "text"
      },
      "source": [
        "# Detectクラスを実装する"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LFV7wK4GGw9T",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# SSDの推論時にconfとlocの出力から、被りを除去したBBoxを出力する\n",
        "\n",
        "\n",
        "class Detect(Function):\n",
        "\n",
        "    def __init__(self, conf_thresh=0.01, top_k=200, nms_thresh=0.45):\n",
        "        self.softmax = nn.Softmax(dim=-1)  # confをソフトマックス関数で正規化するために用意\n",
        "        self.conf_thresh = conf_thresh  # confがconf_thresh=0.01より高いDBoxのみを扱う\n",
        "        self.top_k = top_k  # nm_supressionでconfの高いtop_k個を計算に使用する, top_k = 200\n",
        "        self.nms_thresh = nms_thresh  # nm_supressionでIOUがnms_thresh=0.45より大きいと、同一物体へのBBoxとみなす\n",
        "\n",
        "    def forward(self, loc_data, conf_data, dbox_list):\n",
        "        \"\"\"\n",
        "        順伝搬の計算を実行する。\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        loc_data:  [batch_num,8732,4]\n",
        "            オフセット情報。\n",
        "        conf_data: [batch_num, 8732,num_classes]\n",
        "            検出の確信度。\n",
        "        dbox_list: [8732,4]\n",
        "            DBoxの情報\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        output : torch.Size([batch_num, 21, 200, 5])\n",
        "            （batch_num、クラス、confのtop200、BBoxの情報）\n",
        "        \"\"\"\n",
        "\n",
        "        # 各サイズを取得\n",
        "        num_batch = loc_data.size(0)  # ミニバッチのサイズ\n",
        "        num_dbox = loc_data.size(1)  # DBoxの数 = 8732\n",
        "        num_classes = conf_data.size(2)  # クラス数 = 21\n",
        "\n",
        "        # confはソフトマックスを適用して正規化する\n",
        "        conf_data = self.softmax(conf_data)\n",
        "\n",
        "        # 出力の型を作成する。テンソルサイズは[minibatch数, 21, 200, 5]\n",
        "        output = torch.zeros(num_batch, num_classes, self.top_k, 5)\n",
        "\n",
        "        # cof_dataを[batch_num,8732,num_classes]から[batch_num, num_classes,8732]に順番変更\n",
        "        conf_preds = conf_data.transpose(2, 1)\n",
        "\n",
        "        # ミニバッチごとのループ\n",
        "        for i in range(num_batch):\n",
        "\n",
        "            # 1. locとDBoxから修正したBBox [xmin, ymin, xmax, ymax] を求める\n",
        "            decoded_boxes = decode(loc_data[i], dbox_list)\n",
        "\n",
        "            # confのコピーを作成\n",
        "            conf_scores = conf_preds[i].clone()\n",
        "\n",
        "            # 画像クラスごとのループ（背景クラスのindexである0は計算せず、index=1から）\n",
        "            for cl in range(1, num_classes):\n",
        "\n",
        "                # 2.confの閾値を超えたBBoxを取り出す\n",
        "                # confの閾値を超えているかのマスクを作成し、\n",
        "                # 閾値を超えたconfのインデックスをc_maskとして取得\n",
        "                c_mask = conf_scores[cl].gt(self.conf_thresh)\n",
        "                # gtはGreater thanのこと。gtにより閾値を超えたものが1に、以下が0になる\n",
        "                # conf_scores:torch.Size([21, 8732])\n",
        "                # c_mask:torch.Size([8732])\n",
        "\n",
        "                # scoresはtorch.Size([閾値を超えたBBox数])\n",
        "                scores = conf_scores[cl][c_mask]\n",
        "\n",
        "                # 閾値を超えたconfがない場合、つまりscores=[]のときは、何もしない\n",
        "                if scores.nelement() == 0:  # nelementで要素数の合計を求める\n",
        "                    continue\n",
        "\n",
        "                # c_maskを、decoded_boxesに適用できるようにサイズを変更します\n",
        "                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)\n",
        "                # l_mask:torch.Size([8732, 4])\n",
        "\n",
        "                # l_maskをdecoded_boxesに適応します\n",
        "                boxes = decoded_boxes[l_mask].view(-1, 4)\n",
        "                # decoded_boxes[l_mask]で1次元になってしまうので、\n",
        "                # viewで（閾値を超えたBBox数, 4）サイズに変形しなおす\n",
        "\n",
        "                # 3. Non-Maximum Suppressionを実施し、被っているBBoxを取り除く\n",
        "                ids, count = nm_suppression(\n",
        "                    boxes, scores, self.nms_thresh, self.top_k)\n",
        "                # ids：confの降順にNon-Maximum Suppressionを通過したindexが格納\n",
        "                # count：Non-Maximum Suppressionを通過したBBoxの数\n",
        "\n",
        "                # outputにNon-Maximum Suppressionを抜けた結果を格納\n",
        "                output[i, cl, :count] = torch.cat((scores[ids[:count]].unsqueeze(1),\n",
        "                                                   boxes[ids[:count]]), 1)\n",
        "\n",
        "        return output  # torch.Size([1, 21, 200, 5])\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MYaeLRx_Gw9Z",
        "colab_type": "text"
      },
      "source": [
        "# SSDクラスを実装する"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "55Sgd5HBGw9b",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# SSDクラスを作成する\n",
        "\n",
        "\n",
        "class SSD(nn.Module):\n",
        "\n",
        "    def __init__(self, phase, cfg):\n",
        "        super(SSD, self).__init__()\n",
        "\n",
        "        self.phase = phase  # train or inferenceを指定\n",
        "        self.num_classes = cfg[\"num_classes\"]  # クラス数=21\n",
        "\n",
        "        # SSDのネットワークを作る\n",
        "        self.vgg = make_vgg()\n",
        "        self.extras = make_extras()\n",
        "        self.L2Norm = L2Norm()\n",
        "        self.loc, self.conf = make_loc_conf(\n",
        "            cfg[\"num_classes\"], cfg[\"bbox_aspect_num\"])\n",
        "\n",
        "        # DBox作成\n",
        "        dbox = DBox(cfg)\n",
        "        self.dbox_list = dbox.make_dbox_list()\n",
        "\n",
        "        # 推論時はクラス「Detect」を用意します\n",
        "        if phase == 'inference':\n",
        "            self.detect = Detect()\n",
        "\n",
        "    def forward(self, x):\n",
        "        sources = list()  # locとconfへの入力source1～6を格納\n",
        "        loc = list()  # locの出力を格納\n",
        "        conf = list()  # confの出力を格納\n",
        "\n",
        "        # vggのconv4_3まで計算する\n",
        "        for k in range(23):\n",
        "            x = self.vgg[k](x)\n",
        "\n",
        "        # conv4_3の出力をL2Normに入力し、source1を作成、sourcesに追加\n",
        "        source1 = self.L2Norm(x)\n",
        "        sources.append(source1)\n",
        "\n",
        "        # vggを最後まで計算し、source2を作成、sourcesに追加\n",
        "        for k in range(23, len(self.vgg)):\n",
        "            x = self.vgg[k](x)\n",
        "\n",
        "        sources.append(x)\n",
        "\n",
        "        # extrasのconvとReLUを計算\n",
        "        # source3～6を、sourcesに追加\n",
        "        for k, v in enumerate(self.extras):\n",
        "            x = F.relu(v(x), inplace=True)\n",
        "            if k % 2 == 1:  # conv→ReLU→cov→ReLUをしたらsourceに入れる\n",
        "                sources.append(x)\n",
        "\n",
        "        # source1～6に、それぞれ対応する畳み込みを1回ずつ適用する\n",
        "        # zipでforループの複数のリストの要素を取得\n",
        "        # source1～6まであるので、6回ループが回る\n",
        "        for (x, l, c) in zip(sources, self.loc, self.conf):\n",
        "            # Permuteは要素の順番を入れ替え\n",
        "            loc.append(l(x).permute(0, 2, 3, 1).contiguous())\n",
        "            conf.append(c(x).permute(0, 2, 3, 1).contiguous())\n",
        "            # l(x)とc(x)で畳み込みを実行\n",
        "            # l(x)とc(x)の出力サイズは[batch_num, 4*アスペクト比の種類数, featuremapの高さ, featuremap幅]\n",
        "            # sourceによって、アスペクト比の種類数が異なり、面倒なので順番入れ替えて整える\n",
        "            # permuteで要素の順番を入れ替え、\n",
        "            # [minibatch数, featuremap数, featuremap数,4*アスペクト比の種類数]へ\n",
        "            # （注釈）\n",
        "            # torch.contiguous()はメモリ上で要素を連続的に配置し直す命令です。\n",
        "            # あとでview関数を使用します。\n",
        "            # このviewを行うためには、対象の変数がメモリ上で連続配置されている必要があります。\n",
        "\n",
        "        # さらにlocとconfの形を変形\n",
        "        # locのサイズは、torch.Size([batch_num, 34928])\n",
        "        # confのサイズはtorch.Size([batch_num, 183372])になる\n",
        "        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)\n",
        "        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)\n",
        "\n",
        "        # さらにlocとconfの形を整える\n",
        "        # locのサイズは、torch.Size([batch_num, 8732, 4])\n",
        "        # confのサイズは、torch.Size([batch_num, 8732, 21])\n",
        "        loc = loc.view(loc.size(0), -1, 4)\n",
        "        conf = conf.view(conf.size(0), -1, self.num_classes)\n",
        "\n",
        "        # 最後に出力する\n",
        "        output = (loc, conf, self.dbox_list)\n",
        "\n",
        "        if self.phase == \"inference\":  # 推論時\n",
        "            # クラス「Detect」のforwardを実行\n",
        "            # 返り値のサイズは torch.Size([batch_num, 21, 200, 5])\n",
        "            return self.detect(output[0], output[1], output[2])\n",
        "\n",
        "        else:  # 学習時\n",
        "            return output\n",
        "            # 返り値は(loc, conf, dbox_list)のタプル\n",
        "\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vLqzjJLEGw9e",
        "colab_type": "text"
      },
      "source": [
        "以上"
      ]
    }
  ]
}